import argparse
import collections
import csv
import json
import logging
import os

import tqdm

from llm import *

def parse_args(argv=None):
    """Parse the command-line options of the prompt-generation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads the arguments from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``dataset_name``,
        ``model_name``, ``temperature``, ``max_tokens`` and ``num_probs``.
    """
    parser = argparse.ArgumentParser(description="Run direct generation for various datasets and models.")

    parser.add_argument(
        '--dataset_name',
        type=str,
        required=True,
        choices=['test', 'hi_tom', 'explore', 'tomi', 'gpqa', 'math500', 'aime', 'amc', 'livecode', 'nq', 'triviaqa', 'hotpotqa', '2wiki', 'musique', 'bamboogle', 'medmcqa', 'pubhealth'],
        help="Name of the dataset to use."
    )

    parser.add_argument(
        '--model_name',
        type=str,
        required=True,
        help="Name of the selected model."
    )

    parser.add_argument(
        '--temperature',
        type=float,
        default=0.0,
        help="Sampling temperature."
    )

    # The default is a fixed 4096; the help text no longer claims a
    # model/dataset-dependent default that does not exist.
    parser.add_argument(
        '--max_tokens',
        type=int,
        default=4096,
        help="Maximum number of tokens to generate."
    )

    parser.add_argument(
        '--num_probs', '-n',
        type=int,
        default=300,
        help="Maximum number of samples to process."
    )

    return parser.parse_args(argv)

def split_story_by_lines(story_text):
    """
    Split a story into its individual lines and return them as a list.

    Surrounding whitespace is removed from the whole text and from every
    line; lines that are empty after trimming are dropped.
    """
    sentence_list = []
    # Break the text apart on newline characters.
    for raw_line in story_text.strip().split('\n'):
        trimmed = raw_line.strip()
        # Keep only lines that still have content after trimming.
        if trimmed:
            sentence_list.append(trimmed)
    return sentence_list

def most_common(lst):
    """Return the element that occurs most often in *lst*.

    Raises IndexError when *lst* is empty.
    """
    tally = collections.Counter(lst)
    top_entry = tally.most_common(1)[0]
    element, _occurrences = top_entry
    return element

def read_txt(filename):
    """Return the complete contents of the text file *filename* as a string."""
    with open(filename, mode='r') as handle:
        return handle.read()

# Prompt scaffolding: hi_tom(), explore() and tomi() build every stored prompt
# as `connect_str_1 + <model output> + connect_str_2`.
# connect_str_1 opens an <information> section and lists two fixed rules; the
# model output is appended right after them.
connect_str_1 = """. Wait, perhaps I'm over thinking. Here I can use some of the following information: 
<information>1. An agent can witness everything and every movement only before exiting a room.
2. An agent A can infer another agent B’s mental state only if A and B have been in the same room, or have had private or public interactions.
"""
# connect_str_2 closes the <information> section and ends with the
# "Step-by-Step Explanation" header.
connect_str_2 = """</information>
    Based on these, I may be able to get answers faster.

**Step-by-Step Explanation:**
"""

def hi_tom():
    """Generate step-wise prompts for the Hi-ToM dataset and write them to disk.

    Reads the command-line options via parse_args(), loads the selected
    dataset and, for each sample, feeds the story to the model one line at a
    time: every call receives the current story line together with the
    model's previous output. Each intermediate output is wrapped in
    connect_str_1 / connect_str_2 and stored under the key ``prompt_<i>``.
    Processing stops at the first sample whose ``sample_id + 1`` exceeds
    ``--num_probs``.

    Raises:
        ValueError: If ``--dataset_name`` is neither 'hi_tom' nor 'test'.
    """
    args = parse_args()
    # Slashes are replaced so the name can be embedded in file names.
    model_realname = args.model_name.replace("/", "-")
    logging.basicConfig(
        filename=f"log_{model_realname}_HiToM_prompt.log",  # log file name
        level=logging.INFO,  # one of DEBUG, INFO, WARNING, ERROR, CRITICAL
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    dataset_name = args.dataset_name
    temperature = args.temperature

    # Announce the run on stdout and in the log file.
    print("\n-----------------------------")
    print("    GENERATE HiToM PROMPT     ")
    print("-----------------------------")
    print(f"EVAL MODEL: {args.model_name}")
    print(f"DATA: {args.dataset_name}")
    print(f"N = {args.num_probs}")
    print("-----------------------------\n")
    logging.info("-----------------------------")
    logging.info("    GENERATE HiToM PROMPT      ")
    logging.info("-----------------------------")
    logging.info(f"EVAL MODEL: {args.model_name}")
    logging.info(f"DATA: {args.dataset_name}")
    logging.info(f"N = {args.num_probs}")

    # Paths to datasets
    if dataset_name == 'hi_tom':
        data_path = './data/Hi-ToM_data.json'
    elif dataset_name == 'test':
        data_path = './data/hi-tom_test.json'
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name}")

    output_path = f'./prompts/{dataset_name}/{model_realname}.json'
    # Create ./prompts/<dataset_name>/ on first use; otherwise opening the
    # output file fails with FileNotFoundError.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # The output is written with ensure_ascii=False, so the encoding is
    # pinned to UTF-8 instead of relying on the platform default.
    # NOTE(review): the file is a sequence of JSON objects separated by
    # ",\n", not a single valid JSON document. The format is kept unchanged
    # because downstream readers may depend on it - confirm before changing.
    with open(output_path, "w", encoding="utf-8") as f_out:
        with open(data_path, encoding="utf-8") as f_in:
            data = json.load(f_in)
        for item in tqdm.tqdm(data["data"]):
            sample_id = item["sample_id"]
            # Check the limit first so no model is built for a sample that
            # will not be processed.
            if (sample_id + 1) > args.num_probs:
                break
            sentence_list = split_story_by_lines(item["story"])
            # NOTE(review): the model receives the name with "/" already
            # replaced by "-" - confirm this is what DS_Qwen expects.
            progen = DS_Qwen(model_realname, temperature=temperature)
            info = {"sample_id": sample_id}
            sentence = ''
            for i, prompt in enumerate(sentence_list):
                # Each step sees the current story line and the previous output.
                sentence = str(progen.getOutput(prompt, sentence))
                logging.info(f"i:{i}")
                logging.info(f"prompt:{sentence}")
                info[f"prompt_{i}"] = connect_str_1 + sentence + connect_str_2
            json.dump(info, f_out, ensure_ascii=False, indent=4)
            f_out.write(",\n")

def explore():
    """Generate step-wise prompts for the ExploreToM dataset and write them to disk.

    Reads the command-line options via parse_args(), reads the dataset as
    CSV (skipping the first row as a header) and takes the story from the
    first column of every row. The story is fed to the model one line at a
    time: every call receives the current story line together with the
    model's previous output. Each intermediate output is wrapped in
    connect_str_1 / connect_str_2 and stored under the key ``prompt_<j>``.
    At most ``--num_probs`` rows are processed; sample ids start at 1.

    Raises:
        ValueError: If ``--dataset_name`` is neither 'explore' nor 'test'.
    """
    args = parse_args()
    # Slashes are replaced so the name can be embedded in file names.
    model_realname = args.model_name.replace("/", "-")
    logging.basicConfig(
        filename=f"log_{model_realname}_ExploreToM_prompt.log",  # log file name
        level=logging.INFO,  # one of DEBUG, INFO, WARNING, ERROR, CRITICAL
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    dataset_name = args.dataset_name
    temperature = args.temperature

    # Announce the run on stdout and in the log file.
    print("\n-----------------------------")
    print("    GENERATE ExploreToM PROMPT     ")
    print("-----------------------------")
    print(f"EVAL MODEL: {args.model_name}")
    print(f"DATA: {args.dataset_name}")
    print(f"N = {args.num_probs}")
    print("-----------------------------\n")
    logging.info("-----------------------------")
    logging.info("    GENERATE ExploreToM PROMPT      ")
    logging.info("-----------------------------")
    logging.info(f"EVAL MODEL: {args.model_name}")
    logging.info(f"DATA: {args.dataset_name}")
    logging.info(f"N = {args.num_probs}")

    # Paths to datasets
    if dataset_name == 'explore':
        data_path = './data/ExploreToM-data-sample.csv'
    elif dataset_name == 'test':
        # NOTE(review): this path points at a .json file but is parsed as
        # CSV below - confirm the intended test fixture.
        data_path = './data/hi-tom_test.json'
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name}")

    output_path = f'./prompts/{dataset_name}/{model_realname}.json'
    # Create ./prompts/<dataset_name>/ on first use; otherwise opening the
    # output file fails with FileNotFoundError.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # The output is written with ensure_ascii=False, so the encoding is
    # pinned to UTF-8 instead of relying on the platform default.
    # NOTE(review): the file is a sequence of JSON objects separated by
    # ",\n", not a single valid JSON document. The format is kept unchanged
    # because downstream readers may depend on it - confirm before changing.
    with open(output_path, "w", encoding="utf-8") as f_out:
        # newline='' lets the csv module handle line endings itself, which
        # matters because story cells contain embedded newlines.
        with open(data_path, 'r', newline='', encoding="utf-8") as f_in:
            reader = csv.reader(f_in)
            next(reader, None)  # skip the header row; tolerate an empty file
            sample_count = 0
            for row in reader:
                # csv.reader yields [] for blank lines; they carry no story.
                if not row:
                    continue
                sample_count += 1
                # Check the limit first so no model is built for a row that
                # will not be processed.
                if sample_count > args.num_probs:
                    break
                sentence_list = split_story_by_lines(row[0])
                # NOTE(review): the model receives the name with "/" already
                # replaced by "-" - confirm this is what DS_Qwen expects.
                progen = DS_Qwen(model_realname, temperature=temperature)
                info = {"sample_id": sample_count}
                sentence = ''
                for j, prompt in enumerate(sentence_list):
                    # Each step sees the current story line and the previous output.
                    sentence = str(progen.getOutput(prompt, sentence))
                    logging.info(f"action:{prompt}")
                    logging.info(f"j:{j}")
                    logging.info(f"prompt:{sentence}")
                    info[f"prompt_{j}"] = connect_str_1 + sentence + connect_str_2
                json.dump(info, f_out, ensure_ascii=False, indent=4)
                f_out.write(",\n")
    
def tomi():
    """Generate step-wise prompts for the ToMi dataset and write them to disk.

    Reads the command-line options via parse_args(), reads the dataset as
    JSON Lines (one JSON object per line) and takes the ``story`` field of
    every record. The story is fed to the model one line at a time: every
    call receives the current story line together with the model's previous
    output. Each intermediate output is wrapped in
    connect_str_1 / connect_str_2 and stored under the key ``prompt_<j>``.
    At most ``--num_probs`` records are processed; sample ids start at 1.

    Raises:
        ValueError: If ``--dataset_name`` is neither 'tomi' nor 'test'.
    """
    args = parse_args()
    # Slashes are replaced so the name can be embedded in file names.
    model_realname = args.model_name.replace("/", "-")
    logging.basicConfig(
        filename=f"log_{model_realname}_ToMi_prompt.log",  # log file name
        level=logging.INFO,  # one of DEBUG, INFO, WARNING, ERROR, CRITICAL
        format='%(asctime)s - %(levelname)s - %(message)s'
    )
    dataset_name = args.dataset_name
    temperature = args.temperature

    # Announce the run on stdout and in the log file.
    print("\n-----------------------------")
    print("    GENERATE ToMi PROMPT     ")
    print("-----------------------------")
    print(f"EVAL MODEL: {args.model_name}")
    print(f"DATA: {args.dataset_name}")
    print(f"N = {args.num_probs}")
    print("-----------------------------\n")
    logging.info("-----------------------------")
    logging.info("    GENERATE ToMi PROMPT      ")
    logging.info("-----------------------------")
    logging.info(f"EVAL MODEL: {args.model_name}")
    logging.info(f"DATA: {args.dataset_name}")
    logging.info(f"N = {args.num_probs}")

    # Paths to datasets
    if dataset_name == 'tomi':
        data_path = './data/test_balanced.jsonl'
    elif dataset_name == 'test':
        # NOTE(review): this path points at a .json file but is parsed line
        # by line as JSON Lines below - confirm the intended test fixture.
        data_path = './data/hi-tom_test.json'
    else:
        raise ValueError(f"Unsupported dataset_name: {dataset_name}")

    output_path = f'./prompts/{dataset_name}/{model_realname}.json'
    # Create ./prompts/<dataset_name>/ on first use; otherwise opening the
    # output file fails with FileNotFoundError.
    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    # Read the input once: this gives the record count for the progress bar
    # without a second pass over the file, and drops blank lines that
    # json.loads would reject.
    with open(data_path, encoding="utf-8") as f_in:
        records = [line.strip() for line in f_in if line.strip()]

    # The output is written with ensure_ascii=False, so the encoding is
    # pinned to UTF-8 instead of relying on the platform default.
    # NOTE(review): the file is a sequence of JSON objects separated by
    # ",\n", not a single valid JSON document. The format is kept unchanged
    # because downstream readers may depend on it - confirm before changing.
    with open(output_path, "w", encoding="utf-8") as f_out:
        progress = tqdm.tqdm(records, total=min(len(records), args.num_probs))
        for sample_count, line in enumerate(progress, start=1):
            # Check the limit first so nothing is parsed or built for a
            # record that will not be processed.
            if sample_count > args.num_probs:
                break
            fields = json.loads(line)
            sentence_list = split_story_by_lines(fields["story"])
            # NOTE(review): the model receives the name with "/" already
            # replaced by "-" - confirm this is what DS_Qwen expects.
            progen = DS_Qwen(model_realname, temperature=temperature)
            info = {"sample_id": sample_count}
            sentence = ''
            for j, prompt in enumerate(sentence_list):
                # Each step sees the current story line and the previous output.
                sentence = str(progen.getOutput(prompt, sentence))
                logging.info(f"j:{j}")
                logging.info(f"prompt:{sentence}")
                info[f"prompt_{j}"] = connect_str_1 + sentence + connect_str_2
            json.dump(info, f_out, ensure_ascii=False, indent=4)
            f_out.write(",\n")

if __name__ == "__main__":
    # Run the generator that matches --dataset_name instead of always calling
    # hi_tom(), which rejects 'explore' and 'tomi' with ValueError.
    # Every other value (including 'hi_tom' and 'test') still goes to
    # hi_tom(), exactly as before. Each generator parses the arguments again
    # on its own.
    _generators = {'explore': explore, 'tomi': tomi}
    _generators.get(parse_args().dataset_name, hi_tom)()